package me.yricky.kaguya.dsl

import me.yricky.kaguya.base.*
import me.yricky.kaguya.base.lu.*

/**
 * Receiver scope of the module-building DSL.
 *
 * While a builder block runs, the scope accumulates the signals and logic units that
 * [module] later packs into a [Module]. Each operator below creates one logic unit,
 * appends it to [_logicUnits] and returns that unit's output wire.
 *
 * Binary operators validate that both operands have the same width with `require`,
 * so the check is always active (a plain `assert` is skipped on the JVM unless
 * assertions are enabled with `-ea`).
 */
open class ModuleScope {
    /** Module inputs keyed by name. Nothing in this class writes to it; presumably filled by helpers declared elsewhere. */
    internal val _inputs: MutableMap<String, NamedWire> = mutableMapOf()

    /** Module outputs keyed by name; written by [exportAsOutput]. */
    internal val _outputs: MutableMap<String, NamedSignal> = mutableMapOf()

    /** Signals internal to the module, keyed by name; an entry is dropped when its name is exported. */
    internal val _innerSignals: MutableMap<String, Signal> = mutableMapOf()

    /** Every logic unit created inside this scope, in creation order. */
    internal val _logicUnits: MutableList<LogicUnit> = mutableListOf()

    /** Appends [unit] to [_logicUnits] and returns it so the caller can read its output wire. */
    private fun <U : LogicUnit> registerUnit(unit: U): U {
        _logicUnits.add(unit)
        return unit
    }

    /**
     * Marks this named signal as a module output.
     *
     * The name is first passed to `checkSignalNameNotExported` (declared elsewhere;
     * presumably rejects a name that is already exported — confirm). The signal is then
     * removed from the inner-signal table and stored in the output table.
     *
     * @return the receiver, to allow chaining.
     */
    fun <T : NamedSignal> T.exportAsOutput(): T {
        checkSignalNameNotExported(name)
        _innerSignals.remove(name)
        _outputs[name] = this
        return this
    }

    /** Binds this expression to a new wire called [name] via [assigned] and exports that wire as an output. */
    fun ExpressionWire.exportAsOutput(name: String): NamedWire = assigned(name).exportAsOutput()

    /** Registers an [Inv] unit over this signal and returns its output wire. */
    fun Signal.inv(): ExpressionWire = registerUnit(Inv(this)).wire

    /**
     * Registers a [Slice] unit over this signal and returns its output wire.
     *
     * @param start first argument forwarded to [Slice] (presumably the starting bit index — confirm).
     * @param length second argument forwarded to [Slice]; defaults to 1.
     */
    fun NamedSignal.split(start: Int, length: Int = 1): ExpressionWire =
        registerUnit(Slice(this, start, length)).wire

    /** Index operator: shorthand for `split(i)`, i.e. a slice of the default length 1 starting at [i]. */
    operator fun NamedSignal.get(i: Int): ExpressionWire = split(i)

    /**
     * Registers an [Or] unit over this signal and [that] and returns its output wire.
     *
     * @throws IllegalArgumentException if the two widths differ.
     */
    infix fun Signal.or(that: Signal): ExpressionWire {
        require(width == that.width) { "or: operand widths differ ($width vs ${that.width})" }
        return registerUnit(Or(listOf(this, that))).wire
    }

    /**
     * Registers a [Plus] unit over this signal and [that] and returns its output wire.
     *
     * @throws IllegalArgumentException if the two widths differ.
     */
    operator fun Signal.plus(that: Signal): ExpressionWire {
        require(width == that.width) { "plus: operand widths differ ($width vs ${that.width})" }
        return registerUnit(Plus(Pair(this, that))).wire
    }

    /**
     * Registers a [MoreThan] unit over this signal and [that] and returns its output wire.
     *
     * @throws IllegalArgumentException if the two widths differ.
     */
    infix fun Signal.moreThan(that: Signal): ExpressionWire {
        require(width == that.width) { "moreThan: operand widths differ ($width vs ${that.width})" }
        return registerUnit(MoreThan(Pair(this, that))).wire
    }

    /**
     * Registers a [LessThan] unit over this signal and [that] and returns its output wire.
     *
     * @throws IllegalArgumentException if the two widths differ.
     */
    infix fun Signal.lessThan(that: Signal): ExpressionWire {
        require(width == that.width) { "lessThan: operand widths differ ($width vs ${that.width})" }
        return registerUnit(LessThan(Pair(this, that))).wire
    }

    /** Registers an [SOr] unit over this single signal (presumably an OR reduction — confirm) and returns its wire. */
    fun Signal.or(): ExpressionWire = registerUnit(SOr(this)).wire

    /** Registers an [SAnd] unit over this single signal (presumably an AND reduction — confirm) and returns its wire. */
    fun Signal.and(): ExpressionWire = registerUnit(SAnd(this)).wire

    /** Registers an [SXor] unit over this single signal (presumably an XOR reduction — confirm) and returns its wire. */
    fun Signal.xor(): ExpressionWire = registerUnit(SXor(this)).wire

    /**
     * Registers an [Xor] unit over this signal and [that] and returns its output wire.
     *
     * @throws IllegalArgumentException if the two widths differ.
     */
    infix fun Signal.xor(that: Signal): ExpressionWire {
        require(width == that.width) { "xor: operand widths differ ($width vs ${that.width})" }
        return registerUnit(Xor(listOf(this, that))).wire
    }

    /**
     * Registers an [And] unit over this signal and [that] and returns its output wire.
     *
     * @throws IllegalArgumentException if the two widths differ.
     */
    infix fun Signal.and(that: Signal): ExpressionWire {
        require(width == that.width) { "and: operand widths differ ($width vs ${that.width})" }
        return registerUnit(And(listOf(this, that))).wire
    }

    /**
     * Registers an [Equal] unit over this signal and [that] and returns its output wire.
     *
     * @throws IllegalArgumentException if the two widths differ.
     */
    infix fun Signal.eq(that: Signal): ExpressionWire {
        require(width == that.width) { "eq: operand widths differ ($width vs ${that.width})" }
        return registerUnit(Equal(Pair(this, that))).wire
    }

    /**
     * Registers an [Assign] unit constructed as `Assign(that, this)`.
     *
     * NOTE(review): judging by [assigned], this wire is the one being driven and [that]
     * is its source — confirm against [Assign].
     *
     * @throws IllegalArgumentException if the two widths differ.
     */
    infix fun NamedWire.assignTo(that: Signal) {
        require(width == that.width) { "assignTo: operand widths differ ($width vs ${that.width})" }
        registerUnit(Assign(that, this))
    }

    /**
     * Creates a wire called [name] with this signal's width (through `wire`, declared
     * elsewhere), connects it to this signal with [assignTo] and returns the new wire.
     */
    fun Signal.assigned(name: String): NamedWire {
        val target = wire(name, width)
        target.assignTo(this)
        return target
    }

    /** Registers a [Repeat] unit built from this signal and [count] and returns its output wire. */
    fun Signal.repeat(count: Int): ExpressionWire = registerUnit(Repeat(this, count)).wire

    /**
     * Instantiates the receiver module inside this scope by registering a [ModuleUnit].
     *
     * @param name first argument forwarded to [ModuleUnit] (presumably the instance name — confirm).
     * @param inputs signals of this scope keyed by string (presumably the sub-module's input port names).
     * @param outputs wires of this scope keyed by string (presumably the sub-module's output port names).
     */
    fun Module.impl(
        name: String,
        inputs: Map<String, NamedSignal>,
        outputs: Map<String, NamedWire>,
    ) {
        registerUnit(ModuleUnit(name, inputs, outputs, this))
    }
}

/**
 * Builds a [Module] called [name] by running [builder] against a fresh [ModuleScope].
 *
 * After the builder returns, the scope's input, output and inner-signal tables are
 * snapshotted into lists. The logic-unit list is handed to the [Module] as-is, not copied.
 *
 * @param name name given to the resulting module.
 * @param builder DSL block executed with a new [ModuleScope] as its receiver.
 * @return the assembled [Module].
 */
fun module(
    name: String,
    builder: ModuleScope.() -> Unit
): Module {
    val scope = ModuleScope().apply(builder)

    val io = Module.IO(
        scope._inputs.values.toList(),
        scope._outputs.values.toList()
    )
    val innerSignals = scope._innerSignals.values.toList()

    return Module(name, io, innerSignals, scope._logicUnits)
}
